{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!guardrails hub install hub://tryolabs/restricttotopic"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# On Topic Validation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This validator checks if a text is related with a topic. Using a list of valid topics (which can include one or many) and optionally a list of invalid topics, it validates that the text's main topic is one of the valid ones. If none of the valid topics are relevant, the topic 'Other' will be considered as the most relevant one and the validator will fail.\n",
    "\n",
    "The validator supports 3 different variants:\n",
    "\n",
    "1. Using an ensemble of Zero-Shot classifier + LLM fallback: if the original classification score is less than 0.5, an LLM is used to classify the main topic. This is the default behavior, setting `disable_classifier = False` and `disable_llm = False`.\n",
    "2. Using just a Zero-Shot classifier to get the main topic (`disable_classifier = False` and `disable_llm = True`).\n",
    "3. Using just an LLM to classify the main topic (`disable_classifier = True` and `disable_llm = False`).\n",
    "\n",
    "To use the LLM, you can pass in a name of any OpenAI ChatCompletion model like `gpt-3.5-turbo` or `gpt-4` as the `llm_callable`, or pass in a callable that handles LLM calls. This callable can use any LLM, that you define. For simplicity purposes, we show here a demo of using OpenAI's gpt-3.5-turbo model.\n",
    "\n",
    "To use the OpenAI API, you have 3 options:\n",
    "\n",
    "1. Set the OPENAI_API_KEY environment variable: os.environ[\"OPENAI_API_KEY\"] = \"[OpenAI_API_KEY]\"\n",
    "2. Set the OPENAI_API_KEY using openai.api_key=\"[OpenAI_API_KEY]\"\n",
    "3. Pass the api_key as a parameter to the parse function"
   ]
  },
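  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ensemble variant described above can be sketched in plain Python. This is a minimal illustration, not the validator's actual implementation: `pick_topic`, `unsure_classifier`, `stub_llm`, and `sketch_topics` are hypothetical names, and only the 0.5 threshold and the 'Other' topic come from the description above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable, List, Tuple\n",
    "\n",
    "THRESHOLD = 0.5  # classifier scores below this trigger the LLM fallback\n",
    "\n",
    "\n",
    "def pick_topic(\n",
    "    text: str,\n",
    "    candidate_topics: List[str],\n",
    "    classify_zero_shot: Callable[[str, List[str]], Tuple[str, float]],\n",
    "    classify_with_llm: Callable[[str, List[str]], str],\n",
    ") -> str:\n",
    "    # Ask the LLM only when the classifier's best score is below the threshold.\n",
    "    topic, score = classify_zero_shot(text, candidate_topics)\n",
    "    if score < THRESHOLD:\n",
    "        topic = classify_with_llm(text, candidate_topics)\n",
    "    return topic\n",
    "\n",
    "\n",
    "# Hypothetical stubs standing in for a real classifier and a real LLM call.\n",
    "def unsure_classifier(text, topics):\n",
    "    return \"bike\", 0.3  # low score, so the fallback is used\n",
    "\n",
    "\n",
    "def stub_llm(text, topics):\n",
    "    return \"tablet\"\n",
    "\n",
    "\n",
    "sketch_topics = [\"bike\", \"phone\", \"tablet\", \"computer\", \"Other\"]\n",
    "main_topic = pick_topic(\"A sleek new tablet.\", sketch_topics, unsure_classifier, stub_llm)\n",
    "\n",
    "# Validation passes only when the main topic is one of the valid topics.\n",
    "print(main_topic, main_topic in [\"bike\"])"
   ]
  },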
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up a list of valid and invalid topics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "valid_topics = [\"bike\"]\n",
    "invalid_topics = [\"phone\", \"tablet\", \"computer\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the target topic"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "text = \"\"\"Introducing the Galaxy Tab S7, a sleek and sophisticated device that seamlessly combines \\\n",
    "cutting-edge technology with unparalleled design. With a stunning 5.1-inch Quad HD Super AMOLED display, \\\n",
    "every detail comes to life in vibrant clarity. The Samsung Galaxy S7 boasts a powerful processor, \\\n",
    "ensuring swift and responsive performance for all your tasks. \\\n",
    "Capture your most cherished moments with the advanced camera system, which delivers stunning photos in any lighting conditions.\"\"\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the device\n",
    "\n",
    "The argument `device` is an ordinal to indicate CPU/GPU support for the Zero-shot classifier. Setting this to -1 (default) will leverage CPU, a positive will run the model on the associated CUDA device id."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = -1"
   ]
  },
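  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you are unsure whether a GPU is available, a small check like the one below can suggest an ordinal. This is only a sketch that assumes PyTorch is the classifier's backend; `suggested_device` is a hypothetical name, and the validator cells below keep using `device` as set above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "try:\n",
    "    import torch\n",
    "\n",
    "    cuda_available = torch.cuda.is_available()\n",
    "except ImportError:\n",
    "    # Without PyTorch installed, fall back to the CPU ordinal.\n",
    "    cuda_available = False\n",
    "\n",
    "# 0 selects the first CUDA device, -1 selects the CPU.\n",
    "suggested_device = 0 if cuda_available else -1\n",
    "print(suggested_device)"
   ]
  },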
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up the model\n",
    "\n",
    "The argument `model` indicates the model that will be used to classify the topic. See a list of all models [here](https://huggingface.co/models?pipeline_tag=zero-shot-classification&sort=trending)."
   ]
  },
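  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of choosing a value for this argument, assuming the validator accepts a Hugging Face model identifier through a `model` keyword as described above. `facebook/bart-large-mnli` is a widely used zero-shot classification checkpoint and is picked here only for illustration; the cells below do not pass `model`, so they run with whatever default the validator ships with."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Illustrative choice only; any zero-shot classification model id could be used here.\n",
    "model = \"facebook/bart-large-mnli\""
   ]
  },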
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test the validator"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Version 1: Ensemble"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Here, we use the text we defined above as an example llm output (llm_output). This sample text is about the topic 'tablet', which is explicitly mentioned in our 'invalid_topics' list. We expect the validator to fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/zayd/workspace/guardrails/.venv/lib/python3.9/site-packages/torch/cuda/__init__.py:611: UserWarning: Can't initialize NVML\n",
      "  warnings.warn(\"Can't initialize NVML\")\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation failed for field with errors: Most relevant topic is tablet.\n"
     ]
    }
   ],
   "source": [
    "import guardrails as gd\n",
    "from guardrails.hub import RestrictToTopic\n",
    "from guardrails.errors import ValidationError\n",
    "\n",
    "# Create the Guard with the OnTopic Validator\n",
    "guard = gd.Guard.from_string(\n",
    "    validators=[\n",
    "        RestrictToTopic(\n",
    "            valid_topics=valid_topics,\n",
    "            invalid_topics=invalid_topics,\n",
    "            device=device,\n",
    "            llm_callable=\"gpt-3.5-turbo\",\n",
    "            disable_classifier=False,\n",
    "            disable_llm=False,\n",
    "            on_fail=\"exception\",\n",
    "        )\n",
    "    ],\n",
    ")\n",
    "\n",
    "# Test with a given text\n",
    "try:\n",
    "    guard.parse(\n",
    "        llm_output=text,\n",
    "    )\n",
    "except ValidationError as e:\n",
    "    print(e)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Version 2: Zero-Shot\n",
    "\n",
    "Here, we have disabled the LLM from running at all. We rely totally on what the Zero-Shot classifier outputs. We expect the validator again to fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation failed for field with errors: Most relevant topic is tablet.\n"
     ]
    }
   ],
   "source": [
    "# Create the Guard with the OnTopic Validator\n",
    "guard = gd.Guard.from_string(\n",
    "    validators=[\n",
    "        RestrictToTopic(\n",
    "            valid_topics=valid_topics,\n",
    "            invalid_topics=invalid_topics,\n",
    "            device=device,\n",
    "            disable_classifier=False,\n",
    "            disable_llm=True,\n",
    "            on_fail=\"exception\",\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "# Test with a given text\n",
    "try:\n",
    "    guard.parse(\n",
    "        llm_output=text,\n",
    "    )\n",
    "except ValidationError as e:\n",
    "    print(e)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Version 3: LLM\n",
    "\n",
    "We finally run the validator using the LLM alone, not as a backup to the zero-shot classifier. This cell expects an OPENAI_API_KEY to be present in as an env var. We again expect this cell to fail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Validation failed for field with errors: Most relevant topic is tablet.\n"
     ]
    }
   ],
   "source": [
    "# Create the Guard with the OnTopic Validator\n",
    "guard = gd.Guard.from_string(\n",
    "    validators=[\n",
    "        RestrictToTopic(\n",
    "            valid_topics=valid_topics,\n",
    "            invalid_topics=invalid_topics,\n",
    "            llm_callable=\"gpt-3.5-turbo\",\n",
    "            disable_classifier=True,\n",
    "            disable_llm=False,\n",
    "            on_fail=\"exception\",\n",
    "        )\n",
    "    ],\n",
    ")\n",
    "\n",
    "# Test with a given text\n",
    "try:\n",
    "    guard.parse(\n",
    "        llm_output=text,\n",
    "    )\n",
    "except ValidationError as e:\n",
    "    print(e)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
